<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes">
    <title>WebGPU Compute Shaders - Histogram, optimized more</title>
    <style>
      @import url(resources/webgpu-lesson.css);
      /* every canvas the script appends (image preview and histograms) */
      canvas {
        display: block;
        max-width: 256px;
        border: 1px solid #888;
        background-color: #333;
      }
    </style>
  </head>
  <body>
  </body>
  <script type="module">
// see https://webgpufundamentals.org/webgpu/lessons/webgpu-utils.html#webgpu-utils
import {
  loadImageBitmap,
  createTextureFromSource,
} from '../3rdparty/webgpu-utils-1.x.module.js';

/**
 * Computes red, green, blue and luminance histograms of an image with WebGPU
 * compute shaders and draws them on the page.
 *
 * Outline:
 *  1. The image is split into chunks of chunkWidth x chunkHeight texels. One
 *     workgroup per chunk builds a histogram in workgroup memory and stores
 *     it in that chunk's slot of the `chunks` storage buffer.
 *  2. The per-chunk histograms are added together pairwise over several
 *     dispatches (stride 1, 2, 4, ...) so the total ends up in chunk 0.
 *  3. Chunk 0 is copied into a mappable buffer, read back and drawn.
 *
 * Calls fail() and returns early when no WebGPU device can be obtained.
 */
async function main() {
  const gpuAdapter = await navigator.gpu?.requestAdapter();
  const device = await gpuAdapter?.requestDevice();
  if (!device) {
    fail('need a browser that supports WebGPU');
    return;
  }

  // Constants shared by JavaScript and both WGSL modules. Each entry is
  // emitted into the shader source as `const <name> = <value>;`.
  const k = {
    chunkWidth: 256,
    chunkHeight: 1,
  };
  const { chunkWidth, chunkHeight } = k;
  const chunkSize = chunkWidth * chunkHeight;
  const sharedConstants = Object.entries(k)
    .map(([name, value]) => `const ${name} = ${value};`)
    .join('\n');

  // Pass 1 shader: each invocation bins one texel (r, g, b and luminance in
  // the w slot) into workgroup-shared atomics; after the barrier each
  // invocation writes one bin of the chunk's histogram to the storage buffer.
  const histogramChunkModule = device.createShaderModule({
    label: 'histogram chunk shader',
    code: /* wgsl */ `
      ${sharedConstants}
      const chunkSize = chunkWidth * chunkHeight;
      var<workgroup> bins: array<array<atomic<u32>, 4>, chunkSize>;
      @group(0) @binding(0) var<storage, read_write> chunks: array<array<vec4u, chunkSize>>;
      @group(0) @binding(1) var ourTexture: texture_2d<f32>;

      const kSRGBLuminanceFactors = vec3f(0.2126, 0.7152, 0.0722);
      fn srgbLuminance(color: vec3f) -> f32 {
        return saturate(dot(color, kSRGBLuminanceFactors));
      }

      @compute @workgroup_size(chunkWidth, chunkHeight, 1)
      fn cs(
        @builtin(workgroup_id) workgroup_id: vec3u,
        @builtin(local_invocation_id) local_invocation_id: vec3u,
      ) {
        let size = textureDimensions(ourTexture, 0);
        let position = workgroup_id.xy * vec2u(chunkWidth, chunkHeight) + 
                       local_invocation_id.xy;
        if (all(position < size)) {
          let numBins = f32(chunkSize);
          let lastBinIndex = u32(numBins - 1);
          var channels = textureLoad(ourTexture, position, 0);
          channels.w = srgbLuminance(channels.rgb);
          for (var ch = 0; ch < 4; ch++) {
            let v = channels[ch];
            let bin = min(u32(v * numBins), lastBinIndex);
            atomicAdd(&bins[bin][ch], 1u);
          }
        }

        workgroupBarrier();

        let chunksAcross = (size.x + chunkWidth - 1) / chunkWidth;
        let chunk = workgroup_id.y * chunksAcross + workgroup_id.x;
        let bin = local_invocation_id.y * chunkWidth + local_invocation_id.x;

        chunks[chunk][bin] = vec4u(
          atomicLoad(&bins[bin][0]),
          atomicLoad(&bins[bin][1]),
          atomicLoad(&bins[bin][2]),
          atomicLoad(&bins[bin][3]),
        );
      }
    `,
  });

  // Pass 2 shader: workgroup w adds chunk (2 * w + 1) * stride into chunk
  // 2 * w * stride, one bin per invocation.
  const chunkSumModule = device.createShaderModule({
    label: 'chunk sum shader',
    code: /* wgsl */ `
      ${sharedConstants}
      const chunkSize = chunkWidth * chunkHeight;

      struct Uniforms {
        stride: u32,
      };

      @group(0) @binding(0) var<storage, read_write> chunks: array<array<vec4u, chunkSize>>;
      @group(0) @binding(1) var<uniform> uni: Uniforms;

      @compute @workgroup_size(chunkSize, 1, 1) fn cs(
        @builtin(local_invocation_id) local_invocation_id: vec3u,
        @builtin(workgroup_id) workgroup_id: vec3u,
      ) {
        let chunk0 = workgroup_id.x * uni.stride * 2;
        let chunk1 = chunk0 + uni.stride;

        let sum = chunks[chunk0][local_invocation_id.x] +
                  chunks[chunk1][local_invocation_id.x];
        chunks[chunk0][local_invocation_id.x] = sum;
      }
    `,
  });

  // Both pipelines differ only in label and module.
  const makeComputePipeline = (label, module) => device.createComputePipeline({
    label,
    layout: 'auto',
    compute: { module },
  });
  const histogramChunkPipeline = makeComputePipeline('histogram', histogramChunkModule);
  const chunkSumPipeline = makeComputePipeline('chunk sum', chunkSumModule);

  const imgBitmap = await loadImageBitmap('resources/images/pexels-chevanon-photography-1108099.jpg'); /* webgpufundamentals: url */
  const texture = createTextureFromSource(device, imgBitmap);

  // Same chunk grid the histogram shader derives from the texture size.
  const chunksAcross = Math.ceil(texture.width / chunkWidth);
  const chunksDown = Math.ceil(texture.height / chunkHeight);
  const numChunks = chunksAcross * chunksDown;

  // One histogram is chunkSize bins of vec4u (4 channels x 4 bytes).
  const bytesPerHistogram = chunkSize * 4 * 4;
  const chunksBuffer = device.createBuffer({
    size: numChunks * bytesPerHistogram,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
  });

  // Mappable destination for the final (summed) histogram.
  const resultBuffer = device.createBuffer({
    size: bytesPerHistogram,
    usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
  });

  const histogramBindGroup = device.createBindGroup({
    layout: histogramChunkPipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: chunksBuffer }},
      { binding: 1, resource: texture.createView() },
    ],
  });

  // One bind group per reduction step; each carries its own uniform buffer
  // holding that step's stride (1, 2, 4, ...).
  const numSteps = Math.ceil(Math.log2(numChunks));
  const sumBindGroups = [];
  for (let step = 0, stride = 1; step < numSteps; ++step, stride *= 2) {
    const strideBuffer = device.createBuffer({
      size: 4,
      usage: GPUBufferUsage.UNIFORM,
      mappedAtCreation: true,
    });
    new Uint32Array(strideBuffer.getMappedRange())[0] = stride;
    strideBuffer.unmap();

    sumBindGroups.push(device.createBindGroup({
      layout: chunkSumPipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: chunksBuffer }},
        { binding: 1, resource: { buffer: strideBuffer }},
      ],
    }));
  }

  const encoder = device.createCommandEncoder({ label: 'histogram encoder' });
  const pass = encoder.beginComputePass();

  // step 1: one workgroup per chunk builds that chunk's histogram
  pass.setPipeline(histogramChunkPipeline);
  pass.setBindGroup(0, histogramBindGroup);
  pass.dispatchWorkgroups(chunksAcross, chunksDown);

  // step 2: pairwise reduction; every step merges half (rounded down) of the
  // chunks that still hold partial sums
  pass.setPipeline(chunkSumPipeline);
  let chunksLeft = numChunks;
  for (const bindGroup of sumBindGroups) {
    const dispatchCount = Math.floor(chunksLeft / 2);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(dispatchCount);
    chunksLeft -= dispatchCount;
  }
  pass.end();

  // The total now lives in chunk 0, at the start of chunksBuffer.
  encoder.copyBufferToBuffer(chunksBuffer, 0, resultBuffer, 0, resultBuffer.size);
  device.queue.submit([encoder.finish()]);

  await resultBuffer.mapAsync(GPUMapMode.READ);
  const histogram = new Uint32Array(resultBuffer.getMappedRange());

  showImageBitmap(imgBitmap);

  const numEntries = texture.width * texture.height;

  // red, green and blue share one canvas
  drawHistogram(histogram, numEntries, [0, 1, 2]);

  // luminosity gets a canvas of its own
  drawHistogram(histogram, numEntries, [3]);

  // `histogram` is a view of the mapped range, so unmap only after drawing
  resultBuffer.unmap();
}

/**
 * Appends a new canvas to the page and plots the selected channels of a
 * histogram on it, one 1-pixel-wide bar per bin.
 *
 * @param {ArrayLike<number>} histogram - interleaved counts, 4 values per bin
 *     (channel index = element index % 4).
 * @param {number} numEntries - total number of samples that were counted.
 * @param {number[]} channels - channel indices (0-3) to draw, in draw order.
 * @param {number} [height=100] - canvas height in pixels.
 */
function drawHistogram(histogram, numEntries, channels, height = 100) {
  const numBins = histogram.length / 4;

  // tallest bin of each of the 4 channels
  const peaks = [0, 0, 0, 0];
  for (let i = 0; i < histogram.length; ++i) {
    const ch = i % 4;
    peaks[ch] = Math.max(peaks[ch], histogram[i]);
  }

  // Normalize each channel by its peak, but never scale by less than
  // minScale; bars taller than the canvas simply run off the top.
  const minScale = 0.2 * numBins / numEntries;
  const scale = peaks.map(peak => Math.max(1 / peak, minScale));

  const canvas = Object.assign(document.createElement('canvas'), {
    width: numBins,
    height,
  });
  document.body.appendChild(canvas);
  const ctx = canvas.getContext('2d');

  // indexed by channel: red, green, blue, luminance (white)
  const colors = [
    'rgb(255, 0, 0)',
    'rgb(0, 255, 0)',
    'rgb(0, 0, 255)',
    'rgb(255, 255, 255)',
  ];

  // 'screen' lets overlapping bars of different channels blend
  ctx.globalCompositeOperation = 'screen';

  for (let x = 0; x < numBins; ++x) {
    for (const ch of channels) {
      const barHeight = histogram[x * 4 + ch] * scale[ch] * height;
      ctx.fillStyle = colors[ch];
      ctx.fillRect(x, height - barHeight, 1, barHeight);
    }
  }
}

/**
 * Displays an ImageBitmap by appending a canvas of the same size to the page
 * and handing the bitmap to that canvas's 'bitmaprenderer' context.
 *
 * @param {ImageBitmap} imageBitmap - the bitmap to show.
 */
function showImageBitmap(imageBitmap) {
  const { width, height } = imageBitmap;

  // we have to set the canvas size because of a firefox bug
  // https://bugzilla.mozilla.org/show_bug.cgi?id=1850871
  const canvas = Object.assign(document.createElement('canvas'), { width, height });

  canvas.getContext('bitmaprenderer').transferFromImageBitmap(imageBitmap);
  document.body.appendChild(canvas);
}

/**
 * Reports a fatal problem to the user with a blocking alert dialog.
 *
 * @param {string} msg - text to show.
 */
function fail(msg) {
  // eslint-disable-next-line no-alert
  globalThis.alert(msg);
}

// Start the sample. main() is async, so report a rejection (anything thrown
// by its awaited calls) to the user instead of leaving the promise unhandled.
main().catch((err) => {
  console.error(err);
  fail(err?.message ?? String(err));
});
  </script>
</html>
